import alg
import numpy as np
import time
import random
import math
from scipy.optimize import fsolve

# S-OWA
def OLSMinus(
        T: int,
        K: int,
        gamma: float,
        consumptionArray: np.ndarray,  # L*W
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
        granularity: float
) -> [float, float] :
    """Simulate the single-level threshold policy for ``T`` rounds.

    Every round operates on the fixed level index 2.  Round 0 processes
    when ``floor(random() / (3 * gamma))`` is not positive; every later
    round processes when the accumulated consumption is at most the current
    threshold ``theta`` and ``K`` minus the accumulated consumption is at
    least 1.  A skipped round that directly follows a processed one earns
    ``doNothingValue[2, t-1]``.

    After each round the histogram ``h`` is refreshed through
    ``alg.update_hListTrace`` and ``theta`` becomes the first grid point at
    which the running sum of ``h`` reaches ``3 * gamma`` (1 if it never
    does).

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        gamma: Rate parameter; tripled before use.
        consumptionArray: Per-level consumption data; only row 2 is
            forwarded to ``alg.update_hListTrace``.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.
        granularity: Grid step; ``h`` has ``int(1/granularity) + 1`` cells.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    fixedLevel = 2
    numLevels = 1
    gamma = 3 * gamma
    # Only the row of the fixed level is handed to the histogram update.
    consumptionArray = np.array([consumptionArray[2]])
    h = np.array([0] * (int(1 / granularity) + 1))
    theta = 0
    spent = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False

    for t in range(T):
        tic = time.time()

        if t == 0:
            # The first round is admitted based on a single random draw.
            dice = random.random()
            act = not math.floor(dice / gamma) > 0
        else:
            act = bool(spent <= theta and (K - spent) >= 1)

        if act:
            spent = spent + consumptionOracle[fixedLevel, t]
            totalReward += valueArray[fixedLevel, t]
        elif processedLast:
            totalReward += doNothingValue[fixedLevel, t - 1]

        h = alg.update_hListTrace(
            granularity=granularity,
            theta=theta,
            h=h,
            g=consumptionArray,
            t=t,
            L=numLevels,
            gamma=gamma
        )

        # New threshold: first grid point where the running sum of h
        # reaches the target; stays at 1 when the target is never reached.
        theta = 1
        runningSum = 0
        for cell, weight in enumerate(h):
            runningSum += weight
            if runningSum >= numLevels * gamma:
                theta = cell * granularity
                break

        totalRunTime += (time.time() - tic)
        processedLast = act

    totalRunTime /= T
    return totalReward, totalRunTime

def Random(
        T: int,
        L: int,
        K: int,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate a policy that draws uniformly among ``L`` levels and "skip".

    While ``K`` minus the accumulated consumption is at least 1, each round
    draws an index in ``0..L``; indices below ``L`` are processed at that
    level and index ``L`` means the round is skipped.  A skipped round that
    directly follows a processed one earns ``doNothingValue`` of the
    previously processed level at ``t-1``.

    Args:
        T: Number of rounds to simulate.
        L: Number of processing levels.
        K: Total budget compared against the accumulated consumption.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    spent = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False
    previousLevel = 0

    for t in range(T):
        tic = time.time()

        level = L
        act = False
        if (K - spent) >= 1:
            # L + 1 equally likely outcomes; the last one is "do nothing".
            level = math.floor(random.random() / (1 / (L + 1)))
            act = not level > L - 1

        if act:
            spent = spent + consumptionOracle[level, t]
            totalReward += valueArray[level, t]
        elif processedLast:
            totalReward += doNothingValue[previousLevel, t - 1]

        totalRunTime += (time.time() - tic)
        processedLast = act
        previousLevel = level

    totalRunTime /= T
    return totalReward, totalRunTime

def Greedy(
        T: int,
        K: int,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate a policy that always processes at level index 2.

    A round is processed whenever ``K`` minus the accumulated consumption
    is at least 1.  A skipped round that directly follows a processed one
    earns ``doNothingValue[2, t-1]``.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    fixedLevel = 2
    spent = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False

    for t in range(T):
        tic = time.time()

        act = bool((K - spent) >= 1)
        if act:
            spent = spent + consumptionOracle[fixedLevel, t]
            totalReward += valueArray[fixedLevel, t]
        elif processedLast:
            totalReward += doNothingValue[fixedLevel, t - 1]

        totalRunTime += (time.time() - tic)
        processedLast = act

    totalRunTime /= T
    return totalReward, totalRunTime

def Adaptive(
        T: int,
        K: int,
        L: int,
        KList: np.ndarray,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate a rate-limited policy with a uniformly random level.

    Round 0 is always processed.  A later round ``t`` is processed when
    ``K`` minus the accumulated consumption is at least 1 and the fraction
    ``processed / t`` of rounds processed so far is below 1.  The level of
    a processed round is drawn uniformly from ``0..L-1``.  A skipped round
    that directly follows a processed one earns ``doNothingValue`` of the
    previously processed level at ``t-1``.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of processing levels.
        KList: Not read by this function.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    quota = 1
    spent = 0
    handled = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False
    previousLevel = 0

    for t in range(T):
        tic = time.time()

        # The t == 0 test comes first so handled / t is never evaluated
        # with a zero denominator.
        if t == 0 or ((K - spent) >= 1 and handled / t < quota):
            dice = random.random()
            level = min(math.floor(dice / (1 / L)), L - 1)
            act = True
        else:
            level = L
            act = False

        if act:
            spent = spent + consumptionOracle[level, t]
            totalReward += valueArray[level, t]
            handled += 1
        elif processedLast:
            totalReward += doNothingValue[previousLevel, t - 1]

        totalRunTime += (time.time() - tic)
        processedLast = act
        previousLevel = level

    totalRunTime /= T
    return totalReward, totalRunTime

def MAB(
        T: int,
        K: int,
        L: int,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate an explore/exploit policy over ``L`` levels plus "skip".

    The policy alternates between phases of ``5 * (L + 1)`` rounds.  In an
    exploration phase the arms ``0..L`` are played round-robin; in an
    exploitation phase the arm with the largest ratio of accumulated reward
    to accumulated consumption is played.  Arm ``L`` means "do nothing" and
    starts with reward 1 and consumption 0.01.  Nothing is processed once
    ``K`` minus the accumulated consumption drops below 1.  A skipped round
    that directly follows a processed one earns ``doNothingValue`` of the
    previously processed level at ``t-1``, which is credited to arm ``L``.

    The per-arm statistics are sized from ``L`` (they used to be fixed at
    four entries), and an arm with zero accumulated consumption no longer
    raises ``ZeroDivisionError`` during exploitation.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of processing levels.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    W = 0
    accumulatedReward = 0
    averageRunTime = 0
    lastDecision = False
    lastLevel = L
    # One slot per processing level followed by the do-nothing arm.
    levelTotalReward = [0 for _ in range(L)] + [1]
    levelTotalConsumption = [0 for _ in range(L)] + [0.01]
    levelTotalTasks = [0 for _ in range(L + 1)]
    exploration = True
    count = 0

    for t in range(T):
        startTime = time.time()

        if exploration:
            level = int(count % (L + 1))
        else:
            level = 0
            efficiency = 0
            for l in range(L + 1):
                reward = levelTotalReward[l]
                consumption = levelTotalConsumption[l]
                if consumption == 0:
                    # Free reward is infinitely efficient; an arm with no
                    # reward and no consumption is never preferred.
                    candidate = math.inf if reward > 0 else 0
                else:
                    candidate = reward / consumption
                if candidate > efficiency:
                    level = l
                    efficiency = candidate
        count += 1

        # Arm L is "do nothing"; an exhausted budget also forces a skip.
        process = level != L and not (K - W) < 1

        if process:
            e = consumptionOracle[level, t]
            W = W + e
            accumulatedReward += valueArray[level, t]
            levelTotalTasks[level] += 1
            levelTotalConsumption[level] += e
            levelTotalReward[level] += valueArray[level, t]
        elif lastDecision:
            accumulatedReward += doNothingValue[lastLevel, t - 1]
            levelTotalTasks[L] += 1
            levelTotalReward[L] += doNothingValue[lastLevel, t - 1]

        # Switch between exploration and exploitation at the phase boundary.
        if count == 5 * (L + 1):
            count = 0
            exploration = not exploration

        endTime = time.time()
        averageRunTime += (endTime - startTime)
        lastDecision = process
        lastLevel = level

    averageRunTime /= T
    return accumulatedReward, averageRunTime

def GOK(
        T: int,
        K: int,
        L: int,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate a policy that always processes at level index 0.

    A round is processed unless ``K`` minus the accumulated consumption is
    below 1.  A skipped round that directly follows a processed one earns
    ``doNothingValue[0, t-1]``.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of levels; only used as the initial "previous level",
            which is never read before being overwritten.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    fixedLevel = 0
    spent = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False
    previousLevel = L

    for t in range(T):
        tic = time.time()

        act = not (K - spent) < 1
        if act:
            spent = spent + consumptionOracle[fixedLevel, t]
            totalReward += valueArray[fixedLevel, t]
        elif processedLast:
            totalReward += doNothingValue[previousLevel, t - 1]

        totalRunTime += (time.time() - tic)
        processedLast = act
        previousLevel = fixedLevel

    totalRunTime /= T
    return totalReward, totalRunTime

def MPC(
        T: int,
        K: int,
        L: int,
        maxValue: np.ndarray,
        avgConsumption: np.ndarray,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate a policy that picks the level with the best recent efficiency.

    For every level the policy keeps the rewards and consumptions of the
    last ``N = 30`` processed rounds (zero for rounds processed at another
    level).  A level's efficiency is the windowed reward divided by the
    windowed consumption; when the windowed consumption is zero the prior
    ``maxValue[l] / avgConsumption[l]`` is used instead.  If every level
    falls back to its prior, the level is drawn at random; otherwise the
    first level with the largest efficiency is chosen.  Nothing is
    processed once ``K`` minus the accumulated consumption drops below 1.
    A skipped round that directly follows a processed one earns
    ``doNothingValue`` of the previously selected level at ``t-1``.

    The windows are kept at a fixed length of ``N`` instead of growing with
    every processed round, and an unreachable "skip" check after the random
    draw (the draw is already clamped to ``L - 1``) has been removed.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of processing levels.
        maxValue: Per-level value used in the prior efficiency.
        avgConsumption: Per-level consumption used in the prior efficiency.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    N = 30
    W = 0
    accumulatedReward = 0
    averageRunTime = 0
    lastDecision = False
    lastLevel = L
    level = 0
    # Fixed-length sliding windows, oldest entry first.
    levelReward = [[0 for _ in range(N)] for _ in range(L)]
    levelConsumption = [[0 for _ in range(N)] for _ in range(L)]

    for t in range(T):
        startTime = time.time()

        # Efficiency per level; count how many had to fall back to the prior.
        count = 0
        avgEfficiency = []
        for l in range(L):
            windowConsumption = sum(levelConsumption[l])
            if windowConsumption == 0:
                avgEfficiency.append(maxValue[l] / avgConsumption[l])
                count += 1
            else:
                avgEfficiency.append(sum(levelReward[l]) / windowConsumption)

        if count == L:
            # No observations in any window: pick a level at random.
            dice = random.random()
            level = min(math.floor(dice / (1 / (L + 1))), L - 1)
        else:
            level = 0
            maxEff = 0
            for l in range(L):
                if avgEfficiency[l] > maxEff:
                    level = l
                    maxEff = avgEfficiency[l]

        process = not (K - W) < 1

        if process:
            e = consumptionOracle[level, t]
            W = W + e
            reward = valueArray[level, t]
            accumulatedReward += reward
            for l in range(L):
                chosen = l == level
                levelReward[l].append(reward if chosen else 0)
                levelConsumption[l].append(e if chosen else 0)
                # Drop the oldest entry so each window keeps exactly N items.
                levelReward[l].pop(0)
                levelConsumption[l].pop(0)
        elif lastDecision:
            accumulatedReward += doNothingValue[lastLevel, t - 1]

        endTime = time.time()
        averageRunTime += (endTime - startTime)
        lastDecision = process
        lastLevel = level

    averageRunTime /= T
    return accumulatedReward, averageRunTime

# A-OWA
def AverageMagician(
        T: int,
        K: int,
        L: int,
        gamma: float,
        consumptionArray: np.ndarray,  # L*W
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
        granularity: float
) -> [float, float] :
    """Simulate the multi-level threshold policy for ``T`` rounds.

    Each round draws a level uniformly from ``0..L-1``.  Round 0 is always
    processed at the drawn level; a later round is processed when the
    accumulated consumption is at most that level's threshold and ``K``
    minus the accumulated consumption is at least 1.  A skipped round that
    directly follows a processed one earns ``doNothingValue`` of the level
    processed in the previous round at ``t-1``.

    After each round every level's histogram is refreshed through
    ``alg.update_hListTrace`` (called with ``gamma / L``) and its threshold
    becomes the first grid point at which the running sum of the histogram
    reaches ``gamma`` (0 if it never does).

    Two defects of the earlier version are fixed here: the do-nothing
    reward was indexed by the current random draw rather than by the level
    processed in the previous round, and the later-round guard let a draw
    equal to ``L`` through (unlike the round-0 guard), which would index
    past the end of the threshold list.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of processing levels.
        gamma: Rate parameter; divided by ``L`` before use.
        consumptionArray: Per-level consumption data; row ``i`` is
            forwarded to ``alg.update_hListTrace`` for level ``i``.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.
        granularity: Grid step; each histogram has
            ``int(1/granularity) + 1`` cells.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    gamma = gamma / L
    thetaList = [0 for _ in range(L)]
    W = 0
    hList = np.array([[0 for w in range(int(1 / granularity) + 1)] for _ in range(L)])
    accumulatedReward = 0
    averageRunTime = 0
    lastDecision = False
    lastLevel = L  # only read after a processed round has overwritten it

    for t in range(T):
        startTime = time.time()

        dice = random.random()
        level = math.floor(dice / (1 / L))
        if level >= L:
            # Floating-point rounding can push the draw to L; treat it as
            # "do nothing" in every round, not only in round 0.
            process = False
        elif t == 0:
            process = True
        else:
            process = bool(W <= thetaList[level] and (K - W) >= 1)

        if process:
            e = consumptionOracle[level, t]
            W = W + e
            accumulatedReward += valueArray[level, t]
        elif lastDecision:
            # Credit the idle reward of the level processed last round.
            accumulatedReward += doNothingValue[lastLevel, t - 1]

        hList = [
            alg.update_hListTrace(
                granularity=granularity,
                theta=thetaList[i],
                h=hList[i],
                g=np.array([consumptionArray[i]]),
                t=t,
                L=1,
                gamma=gamma
            )
            for i in range(L)
        ]

        # Per-level threshold: first grid point where the running sum of
        # the histogram reaches L * gamma; 0 when it never does.
        for i in range(L):
            threshold = 0
            thetaList[i] = 0
            for w in range(len(hList[i])):
                threshold += hList[i][w]
                if threshold >= L * gamma:
                    thetaList[i] = w * granularity
                    break

        endTime = time.time()
        averageRunTime += (endTime - startTime)
        lastDecision = process
        lastLevel = level

    averageRunTime /= T
    return accumulatedReward, averageRunTime


def UBC(
        T: int,
        K: int,
        L: int,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
) -> [float, float] :
    """Simulate an upper-confidence-bound policy over ``L`` levels plus "skip".

    Rounds ``0..L`` play the arms ``0..L`` once each in order; arm ``L``
    means "do nothing".  Afterwards each arm is scored by its mean reward
    plus ``c * sqrt(log(t) / plays)`` with ``c = 1.75`` and the first arm
    with the largest score is played.  Nothing is processed once ``K``
    minus the accumulated consumption drops below 1.  A skipped round that
    directly follows a processed one earns ``doNothingValue`` of the
    previously selected level at ``t-1``, which is credited to arm ``L``.

    The per-arm statistics are sized from ``L`` (they used to be fixed at
    four entries), and an arm that has never been played gets an infinite
    score instead of raising ``ZeroDivisionError``.

    Args:
        T: Number of rounds to simulate.
        K: Total budget compared against the accumulated consumption.
        L: Number of processing levels.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    c = 1.75
    W = 0
    accumulatedReward = 0
    averageRunTime = 0
    lastDecision = False
    lastLevel = L
    # One slot per processing level followed by the do-nothing arm.
    levelTotalReward = [0 for _ in range(L + 1)]
    levelTotalTasks = [0 for _ in range(L + 1)]

    for t in range(T):
        startTime = time.time()

        if t <= L:
            # Initial sweep: play every arm once, in order.
            level = t
        else:
            UBCList = []
            for i in range(L + 1):
                plays = levelTotalTasks[i]
                if plays == 0:
                    # Unplayed arm: no estimate yet, give it top priority.
                    UBCList.append(math.inf)
                else:
                    mean = levelTotalReward[i] / plays
                    UBCList.append(mean + c * math.pow(math.log(t) / plays, 1 / 2))
            level = UBCList.index(max(UBCList))

        # Arm L is "do nothing"; an exhausted budget also forces a skip.
        process = level != L and not (K - W) < 1

        if process:
            e = consumptionOracle[level, t]
            W = W + e
            accumulatedReward += valueArray[level, t]
            levelTotalTasks[level] += 1
            levelTotalReward[level] += valueArray[level, t]
        elif lastDecision:
            accumulatedReward += doNothingValue[lastLevel, t - 1]
            levelTotalTasks[L] += 1
            levelTotalReward[L] += doNothingValue[lastLevel, t - 1]

        endTime = time.time()
        averageRunTime += (endTime - startTime)
        lastDecision = process
        lastLevel = level

    averageRunTime /= T
    return accumulatedReward, averageRunTime


def OTP(
        K: float,  # g^k_max
        T: int,  # h_k
        L: int,
        estimateT: int,
        consumptionArray: np.ndarray,
        consumptionOracle: np.ndarray,  # L*T
        valueArray: np.ndarray,  # L*T
        doNothingValue: np.ndarray,  # L*T, need to use t-1
        granularity: float,
        maxValue: np.ndarray
) -> [float, float] :
    """Simulate a threshold-function policy for ``T`` rounds.

    A threshold curve ``phi`` is tabulated on the grid
    ``arange(0, 1 + granularity, granularity)`` from the smallest and
    largest entries of ``maxValue`` and a root of ``LambertW`` found with
    ``fsolve``.  Each round scores every level with ``integration`` using
    ``maxValue / 2`` as the value and a unit consumption estimate.  The
    round is processed at the level with the largest ``maxValue`` when
    ``y + (estimateT - 1 - t) / K <= 1`` (``y`` being the consumed share
    ``sum(e) / K``), otherwise at the best-scoring level if its score is
    non-negative; else it is skipped.  A skipped round that directly
    follows a processed one earns ``doNothingValue`` of the previously
    processed level at ``t-1``.

    Args:
        K: Budget used to normalise consumption.
        T: Number of rounds to simulate.
        L: Number of processing levels.
        estimateT: Horizon estimate used in the "enough budget left" test.
        consumptionArray: Not read by this function.
        consumptionOracle: Consumption charged when processing, indexed
            ``[level, t]``.
        valueArray: Reward earned when processing, indexed ``[level, t]``.
        doNothingValue: Reward earned when skipping right after a
            processed round, indexed ``[level, t-1]``.
        granularity: Grid step of the tabulated threshold curve.
        maxValue: Per-level value bound; must support ``maxValue / 2`` and
            ``.tolist()`` on the result.

    Returns:
        ``(accumulatedReward, averageRunTime)`` where the second entry is
        the mean wall-clock time per round.
    """
    grid = np.arange(0, 1 + granularity, granularity)
    averageComsumtion = [1 for _ in range(L)]
    unitValues = [maxValue[i] / averageComsumtion[i] for i in range(L)]
    parameterL = min(unitValues) / 2
    parameterU = max(unitValues)

    # Solve y * exp(y) = (L - U) / (e * U) for y, starting from 0.1.
    root = fsolve(
        LambertW,
        0.1,
        args=((1 / math.e) * (parameterL - parameterU) / parameterU,),
    )
    theta = parameterL * K * (root[0] + 1)
    phi = np.array([
        parameterL * K
        + (theta - parameterL * K) * math.pow(math.e, theta * x / (parameterL * K))
        for x in grid
    ])

    # Neither input changes inside the loop, so the per-level value is
    # computed once up front.
    aList = maxValue / 2

    y = 0
    totalReward = 0
    totalRunTime = 0
    processedLast = False
    previousLevel = L

    for t in range(T):
        tic = time.time()

        scores = [
            integration(K, aList[i], averageComsumtion[i], y, phi, granularity)
            for i in range(L)
        ]
        best = scores.index(max(scores))

        if y + (estimateT - 1 - t) * max(averageComsumtion) / K <= 1:
            # Budget suffices for the remaining estimated rounds.
            level = aList.tolist().index(np.max(aList))
            act = True
        elif scores[best] >= 0:
            level = best
            act = True
        else:
            level = L
            act = False

        if act:
            e = consumptionOracle[level, t]
            y = y + e / K
            totalReward += valueArray[level, t]
        elif processedLast:
            totalReward += doNothingValue[previousLevel, t - 1]

        totalRunTime += (time.time() - tic)
        processedLast = act
        previousLevel = level

    totalRunTime /= T
    return totalReward, totalRunTime


def integration(
        K: float,
        a: float,
        g: float,
        y: float,
        phi: np.ndarray,
        granularity: float
) -> float:
    """Return ``a`` minus the sum of ``phi`` over an index window.

    The window starts at ``round(y / granularity)`` (not below 0) and ends,
    inclusively, at ``round(g / K + y)`` capped at 10.

    Args:
        K: Divisor applied to ``g``.
        a: Value the summed window is subtracted from.
        g: Consumption estimate.
        y: Current position on the threshold axis.
        phi: Tabulated threshold values.
        granularity: Grid step used to convert ``y`` to a start index.

    Returns:
        ``a - sum(phi[start:end + 1])``; equal to ``a`` when the window is
        empty.
    """
    # NOTE(review): the start index is in grid units (y / granularity)
    # while the end index rounds g / K + y without dividing by granularity,
    # and the cap of 10 presumably matches a granularity of 0.1 - confirm
    # that this is intended.
    first = max(round(y / granularity), 0)
    last = min(10, round(g / K + y))
    window = phi[first:last + 1]
    return a - np.sum(window)


def LambertW(y, x):
    """Return the residual ``y * exp(y) - x``.

    A root in ``y`` of this residual satisfies ``y * exp(y) = x``.
    """
    residual = np.exp(y) * y - x
    return residual